#' =============================================================================
#' FILE: 04_causal_forest.R
#' DESCRIPTION:
#'   Estimates heterogeneous treatment effects of cues using causal forests.
#'   Produces variable importance tables and figures for treatment effect
#'   distributions.
#'
#' Creates Table 1 (variable importance) and Figure 6 (distribution of
#' predicted treatment effects).
#'
#' PACKAGES REQUIRED: pacman, tidyverse, grf
#'
#' OUTPUTS:
#'   - 03_figures/fg_06.jpg (Figure 6)
#'   - Table 1, printed to the console
#' =============================================================================

# Packages and dataset ---------------------------------------------------------

# Required packages. requireNamespace() only checks that pacman is installed;
# unlike require() it does not attach the package (p_load is called with `::`)
# and does not emit a warning when the package is missing.
if (!requireNamespace("pacman", quietly = TRUE)) {
  install.packages("pacman")
}
pacman::p_load(tidyverse,
               grf)

# Data
# NOTE(review): `df` is not referenced anywhere else in this script; confirm
# it is still needed before removing the read.
df <- readRDS("04_outputs/clean_dataset.rds")
df_e <- readRDS("04_outputs/clean_dataset_reshaped.rds")

# Identity Strength Measures ---------------------------------------------------

# Each *_st column holds the respondent's strength score for a given identity
# when they hold that identity, and 0 otherwise. Rows whose identity value is
# NA get NA, because `NA == level` is NA inside ifelse().
keep_strength <- function(identity, strength, level) {
  ifelse(identity == level, strength, 0)
}

# Movement-based Identities
df_e$fuji_st <- keep_strength(df_e$cha_id, df_e$cha_st, "Fujimorismo")
df_e$anti_st <- keep_strength(df_e$cha_id, df_e$cha_st, "Anti-Fujimorismo")
# Ideological Identities
df_e$right_st <- keep_strength(df_e$ido_id, df_e$ido_st, "right")
df_e$left_st <- keep_strength(df_e$ido_id, df_e$ido_st, "left")
df_e$center_st <- keep_strength(df_e$ido_id, df_e$ido_st, "center")
# Partisan Identities
df_e$pl_st <- keep_strength(df_e$party_id, df_e$pid_st, "PL")
df_e$fp_st <- keep_strength(df_e$party_id, df_e$pid_st, "FP")
df_e$ap_st <- keep_strength(df_e$party_id, df_e$pid_st, "AP")
df_e$ot_st <- keep_strength(df_e$party_id, df_e$pid_st, "Other")

# Fujimorismo and anti-fujimorismo cues as categorical -------------------------

# Experiment number (exp) -> cue type:
#   Fuji cues: 1, 4, 6
#   Anti cues: 2, 3, 5

# SUBSET DATASET
# %in% keeps exactly the rows the equivalent `==`/`|` chain keeps; rows with
# NA exp are dropped either way.
## Treatment Fuji
df_f <- filter(df_e, exp %in% c(1, 4, 6))
## Treatment Anti
df_a <- filter(df_e, exp %in% c(2, 3, 5))

# Variables for Analysis -------------------------------------------------------

# Covariate formula shared by both forests, so that the Fuji and Anti forests
# (and their side-by-side importance rankings in Table 1) use the same set of
# covariates.
# NOTE(review): the Anti design matrix previously omitted the partisan
# strength terms (pl_st, fp_st, ap_st, ot_st) that the Fuji matrix included;
# both now include them. This changes the Anti forest's estimates.
covariate_formula <- ~ exp +
  cha_id + # Original was with the categorical variables
  party_id +
  ido_id +
  fuji_st +
  anti_st +
  right_st +
  center_st +
  left_st +
  pl_st +
  fp_st +
  ap_st +
  ot_st +
  as.character(region) + # Original was without region, female, educ, and age.
  female +
  educ +
  in_pol +
  trust_p +
  age +
  as.character(order)

# Build the covariate matrix for one subset and drop the "(Intercept)" column.
# model.matrix() silently drops rows with a missing covariate (default
# na.action), which would misalign X with Y/W and with the row indices
# sampled later, so this fails loudly instead.
#   data:    the subset data frame
#   formula: one-sided covariate formula
# Returns a numeric matrix with nrow(data) rows.
build_design <- function(data, formula) {
  X <- model.matrix(formula, data = data)[, -1]
  if (nrow(X) != nrow(data)) {
    stop("model.matrix() dropped ", nrow(data) - nrow(X),
         " row(s) with missing covariates; X would be misaligned with Y and W.",
         call. = FALSE)
  }
  X
}

# NOTE(review): if `position` is a factor, as.numeric() returns level codes,
# not the labels' numeric values. Confirm the column type.
# Fuji Treatment
Y_f <- as.numeric(df_f$position) # Dependent variable
W_f <- ifelse(df_f$cond == "fuji", 1, 0)
X_f <- build_design(df_f, covariate_formula)
# Anti Treatment
Y_a <- as.numeric(df_a$position) # Dependent variable
W_a <- ifelse(df_a$cond == "anti", 1, 0)
X_a <- build_design(df_a, covariate_formula)

# Causal Forest ----------------------------------------------------------------

# 75/25 train/test split. Both index draws come from the same seeded RNG
# stream, Fuji first and Anti second.
set.seed(11142023)
cases_f <- sample(seq_len(nrow(df_f)), round(nrow(df_f) * .75))
cases_a <- sample(seq_len(nrow(df_a)), round(nrow(df_a) * .75))

# Outcomes and treatment indicators, split by the sampled row indices
## Fuji
train_Yf <- Y_f[cases_f]
train_Wf <- W_f[cases_f]
test_Yf <- Y_f[-cases_f]
test_Wf <- W_f[-cases_f]
## Anti
train_Ya <- Y_a[cases_a]
train_Wa <- W_a[cases_a]
test_Ya <- Y_a[-cases_a]
test_Wa <- W_a[-cases_a]

# Covariate matrices, split with the same indices
train_Xf <- X_f[cases_f, ]
test_Xf <- X_f[-cases_f, ]
train_Xa <- X_a[cases_a, ]
test_Xa <- X_a[-cases_a, ]

# One forest per cue type: 5000 trees (more than the default), plus grf's own
# seed argument for reproducibility.
m1_f <- grf::causal_forest(X = train_Xf, Y = train_Yf, W = train_Wf,
                           num.trees = 5000, seed = 1988)
m1_a <- grf::causal_forest(X = train_Xa, Y = train_Ya, W = train_Wa,
                           num.trees = 5000, seed = 1988)

# Convert the test covariates to a data frame and append the predicted
# effects (preds), the observed outcome (Y) and the treatment indicator (W).
add_test_predictions <- function(x_test, cate, y_test, w_test) {
  out <- as.data.frame(x_test)
  out$preds <- cate$predictions
  out$Y <- y_test
  out$W <- w_test
  out
}

# Individual-level treatment effect predictions on the test sets; variance
# estimates are requested for confidence intervals.
priority.cate.f <- predict(m1_f, newdata = test_Xf, estimate.variance = TRUE)
priority.cate.a <- predict(m1_a, newdata = test_Xa, estimate.variance = TRUE)

test_Xf <- add_test_predictions(test_Xf, priority.cate.f, test_Yf, test_Wf)
test_Xa <- add_test_predictions(test_Xa, priority.cate.a, test_Ya, test_Wa)

# Figure 6: Distribution of treatment effect -----------------------------------

# Test-set predicted effects for both cue types, stacked into one long table.
# `exp` is overwritten with a label that drives colour and fill.
fg_06_data <- rbind(
  test_Xf %>%
    mutate(exp = "Fujimorista") %>%
    select(preds, exp),
  test_Xa %>%
    mutate(exp = "Anti-Fujimorista") %>%
    select(preds, exp)
) %>%
  mutate(exp = factor(exp, levels = c("Fujimorista", "Anti-Fujimorista")))

# Dashed reference lines: difference in mean outcome between treated and
# control respondents in each training sample.
ate_f <- mean(train_Yf[train_Wf == 1]) - mean(train_Yf[train_Wf == 0])
ate_a <- mean(train_Ya[train_Wa == 1]) - mean(train_Ya[train_Wa == 0])

fg_06 <- ggplot(fg_06_data, aes(x = preds, colour = exp, fill = exp)) +
  geom_density(alpha = 0.7, adjust = 1.5) +
  scale_color_manual(values = c("Black", "Gray")) +
  scale_fill_manual(values = c("transparent", "Gray")) +
  labs(x = "Predicted Effect", y = "Density") +
  geom_vline(xintercept = ate_f, linetype = "dashed", colour = "black") +
  geom_vline(xintercept = ate_a, linetype = "dashed", colour = "gray") +
  # Zoom with coord_cartesian() rather than scale limits. Scale limits
  # discard predictions outside [-2.5, 2.5] before the density is estimated,
  # which distorts the curves; coordinate limits only crop the view.
  coord_cartesian(xlim = c(-2.5, 2.5)) +
  theme(
    legend.position = "bottom",
    # legend.title = element_blank(),
    legend.key = element_rect(fill = "white"),
    legend.box.background = element_rect(colour = "black"), #, fill = "transparent"
    legend.background = element_blank(), # If not it messes with the rectangle.
    panel.background = element_rect(fill = "white", colour = "white"),
    panel.grid.major = element_line(colour = "grey95"),
    panel.grid.minor = element_line(colour = "grey95"),
    axis.ticks.y = element_blank(),
    axis.ticks.x = element_blank(),
    strip.background = element_rect(fill = "white")
  ) +
  guides(
    colour = guide_legend(title.position = "top", title.hjust = 0.5,
                          title = "Treatment Condition"),
    fill = guide_legend(title.position = "top", title.hjust = 0.5,
                        title = "Treatment Condition")
  )

print(fg_06)

# Save this plot explicitly instead of relying on ggplot2's last_plot().
ggsave("03_figures/fg_06.jpg",
       plot = fg_06,
       height = 6,
       width = 9)

# Table 1: Variable importance -------------------------------------------------

# Top-`n` covariates of a fitted grf forest, ranked by
# grf::variable_importance(). Rows are sorted on the unrounded scores so that
# rounding to 4 decimals cannot reorder near-ties; rounding is for display only.
#   forest: fitted grf forest (its X.orig column names label the rows)
#   n:      number of rows to keep (default 6, as with head())
# Returns a data.frame with columns Variable and Importance.
importance_table <- function(forest, n = 6) {
  scores <- as.numeric(grf::variable_importance(forest))
  data.frame(
    Variable = colnames(forest$X.orig),
    Importance = scores
  ) %>%
    arrange(desc(Importance)) %>%
    mutate(Importance = round(Importance, 4)) %>%
    head(n)
}

# Fujimorista cues
fujimorista_tbl <- importance_table(m1_f)

# Anti-Fujimorista cues
anti_fujimorista_tbl <- importance_table(m1_a)

# Table 1: the two rankings side by side. The columns are prefixed per cue type
# so bind_cols() does not have to repair duplicate names
# (Variable...1, Importance...2, ...).
table_1 <- bind_cols(
  rename(fujimorista_tbl,
         Fuji_Variable = Variable, Fuji_Importance = Importance),
  rename(anti_fujimorista_tbl,
         Anti_Variable = Variable, Anti_Importance = Importance)
)
print(table_1)